{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FuseNet with Classification Head - Loss and Accuracy Evaluation\n",
    "<br>\n",
    "<img src=\"./images/framework_class.jpg\" style=\"width:80%;height:80%;margin:0px 45px\">\n",
    "<p> This notebook provides a simple script to visualize and evaluate the performance of trained FuseNet models. It can be used to evaluate FuseNet models with or without the scene-classification head.</p> \n",
    "<p><b>Note:</b> Please make sure that the checkpoint name of every model with a classification head contains the word 'class'. The notebook distinguishes between the two model types solely by the checkpoint name.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from utils.data_utils import get_data\n",
    "\n",
    "# set default size of plots\n",
    "%matplotlib inline\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0) \n",
    "\n",
    "# for auto-reloading external modules\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Plot the loss, training and validation accuracy \n",
    "\n",
    "During training, the training and validation accuracies are computed over the whole sets and saved at the end of each epoch. Accuracy is calculated globally: for semantic segmentation, global_accuracy = correctly_labelled_pixels / all_labelled_pixels; for scene classification, global_accuracy = correctly_classified_scenes / all_scenes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_accuracy(train_hist, val_hist, ylabel):\n",
    "    \"\"\"Draw per-epoch training and validation accuracy curves in a figure of their own.\"\"\"\n",
    "    _, ax = plt.subplots()\n",
    "    ax.plot(train_hist, '-o')\n",
    "    ax.plot(val_hist, '-o')\n",
    "    ax.legend(['train', 'val'], loc='upper left')\n",
    "    ax.set(xlabel='epoch', ylabel=ylabel)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Ask for the checkpoint path; drop the stray whitespace that tends to come along with a pasted path.\n",
    "model_path = input('[INPUT] Enter checkpoint path please: ').strip()\n",
    "\n",
    "if not os.path.isfile(model_path):\n",
    "    raise FileNotFoundError('No checkpoint found at {}'.format(model_path))\n",
    "\n",
    "print('[PROGRESS] Loading checkpoint: {}   '.format(model_path), end='', flush=True)\n",
    "# NOTE: torch.load unpickles the file, so only open checkpoints from a trusted source.\n",
    "# map_location='cpu' lets a checkpoint that was saved on a GPU load on a CPU-only machine;\n",
    "# only the logged histories are read below, so no GPU is required.\n",
    "checkpoint = torch.load(model_path, map_location='cpu')\n",
    "print('\\r[INFO] Checkpoint loaded: {}'.format(model_path))\n",
    "\n",
    "# Models with a classification head are recognised by the word 'class' in the checkpoint path.\n",
    "use_class = 'class' in model_path.lower()\n",
    "\n",
    "# Not used by the plots below; kept in the namespace for interactive inspection.\n",
    "trained_epoch = checkpoint['epoch']\n",
    "best_val_acc = checkpoint['best_val_seg_acc']\n",
    "\n",
    "train_loss_history = checkpoint['train_loss_hist']\n",
    "val_seg_acc_history = checkpoint['val_seg_acc_hist']\n",
    "train_seg_acc_history = checkpoint['train_seg_acc_hist']\n",
    "\n",
    "if use_class:\n",
    "    train_seg_loss_history = checkpoint['train_seg_loss_hist']\n",
    "    train_class_loss_history = checkpoint['train_class_loss_hist']\n",
    "    train_class_acc_history = checkpoint['train_class_acc_hist']\n",
    "    val_class_acc_history = checkpoint['val_class_acc_hist']\n",
    "\n",
    "# Training loss: one full-size figure, with the extra loss terms when a classification head is present.\n",
    "_, loss_ax = plt.subplots()\n",
    "loss_ax.plot(train_loss_history, 'o')\n",
    "if use_class:\n",
    "    loss_ax.plot(train_seg_loss_history, 'o')\n",
    "    loss_ax.plot(train_class_loss_history, 'o')\n",
    "    loss_ax.legend(['glob_loss', 'seg_loss', 'class_loss'], loc='upper left')\n",
    "else:\n",
    "    loss_ax.legend(['seg_loss'], loc='upper left')\n",
    "loss_ax.set(xlabel='epoch', ylabel='loss')\n",
    "plt.show()\n",
    "\n",
    "# The reported epoch is the 0-based position of the value inside the stored history.\n",
    "print('Minimum loss values achieved during training:')\n",
    "print('Global: %.3f at epoch: %i' % (np.min(train_loss_history), np.argmin(train_loss_history)))\n",
    "if use_class:\n",
    "    print('Segmentation: %.3f at epoch: %i' % (np.min(train_seg_loss_history), np.argmin(train_seg_loss_history)))\n",
    "    print('Classification: %.3f at epoch: %i \\n' % (np.min(train_class_loss_history), np.argmin(train_class_loss_history)))\n",
    "\n",
    "plot_accuracy(train_seg_acc_history, val_seg_acc_history, 'seg. accuracy')\n",
    "if use_class:\n",
    "    plot_accuracy(train_class_acc_history, val_class_acc_history, 'class. accuracy')\n",
    "\n",
    "print('Best accuracy values achieved during training (on validation set):')\n",
    "print('Global pixel-wise classification: %.3f achieved on epoch: %i' % (np.max(val_seg_acc_history), np.argmax(val_seg_acc_history)))\n",
    "if use_class:\n",
    "    print('Global scene classification: %.3f achieved on epoch: %i' % (np.max(val_class_acc_history), np.argmax(val_class_acc_history)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
